import os
os.environ["CUDA_HOME"] = "/usr/local/cuda"
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    # CodeLlamaTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer

# Report the visible GPUs. torch.cuda.get_device_name() raises when CUDA is not
# available, so guard it to keep this purely diagnostic print from crashing the script.
if torch.cuda.is_available():
    print(f"Using {torch.cuda.device_count()} {torch.cuda.get_device_name()} GPUs.")
else:
    print("Using 0 GPUs (CUDA is not available).")
# Training data: ./data/llama_<data>_<DataNum>.jsonl, loaded as a single
# "train" split. Other `data` values this script has been run with:
# "bash", "combined", "clean", "year".
num_train_epochs = 200
DataNum = 20
data = 'simple'
data_files = f'./data/llama_{data}_{DataNum}.jsonl'
train_dataset = load_dataset(
    'json',
    data_files=data_files,
    split="train",
)

# ---- Model / run configuration ---------------------------------------------
model_size = '7b'
ck = 12000          # step number of the checkpoint used when resuming
is_recover = False  # True -> resume training from checkpoint `ck`
model_name = f"codellama/CodeLlama-{model_size}-Instruct-hf"
# model_name = f"../finetune/llama-2-{model_size}-simple_{DataNum}_lora/checkpoint-8000"

print(f'fine-tuning {model_name}.')
# new_model = f"llama-2-{model_size}-simple_{data}_{DataNum}_lora"
new_model = f"llama-2-{model_size}-{data}_{DataNum}_lora"
print(f'saving to {new_model}')
device_map = "balanced"

# LoRA adapter hyperparameters (rank, scaling alpha, dropout).
lora_r = 64
lora_alpha = 16
lora_dropout = 0.1

# All checkpoints and the final adapter are written under this directory.
output_dir = f"./{new_model}"
os.makedirs(output_dir, exist_ok=True)

# Mixed-precision training is disabled.
fp16 = False
bf16 = False

# Batch size is 1 for every model size.
per_device_train_batch_size = 1
print(f"per_device_train_batch_size {per_device_train_batch_size}")

# ---- Optimisation hyperparameters -------------------------------------------
# per_device_eval_batch_size = 4
gradient_accumulation_steps = 1
gradient_checkpointing = True
max_grad_norm = 0.3
learning_rate = 2e-4
weight_decay = 0.001
# optim = "paged_adamw_32bit"
lr_scheduler_type = "constant"
max_steps = -1  # -1: training length is governed by num_train_epochs
# NOTE(review): the "constant" scheduler applies no warmup, so warmup_ratio has
# no effect; use "constant_with_warmup" if a warmup phase is intended.
warmup_ratio = 0.03
group_by_length = True
save_steps = 2000
logging_steps = 100
max_seq_length = None  # None defers to SFTTrainer's default sequence length
packing = False

# Tokenizer: pad on the left and register a dedicated '[PAD]' special token.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.padding_side = 'left'
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# Base model, placed across devices according to `device_map`.
model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map)

# Grow the embedding matrix to the tokenizer's (now larger) vocabulary so the
# added token has a row, then record the pad id in the model config.
with torch.no_grad():
    model.resize_token_embeddings(len(tokenizer))
model.config.pad_token_id = tokenizer.pad_token_id
# Hugging Face trainer configuration, built from the hyperparameters above.
training_arguments = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    # optim=optim,
    # Pass the configured flag through; TrainingArguments defaults it to False,
    # so without this the `gradient_checkpointing` setting has no effect.
    gradient_checkpointing=gradient_checkpointing,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    fp16=fp16,
    bf16=bf16,
    max_grad_norm=max_grad_norm,
    max_steps=max_steps,
    warmup_ratio=warmup_ratio,
    group_by_length=group_by_length,
    lr_scheduler_type=lr_scheduler_type,
    report_to="tensorboard",
)

# LoRA adapter configuration for a causal-LM task; bias terms are not trained.
peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_r,
    bias="none",
    task_type="CAUSAL_LM",
)

# Supervised fine-tuning trainer: reads raw text from the dataset's "text"
# field and applies the LoRA adapter described by `peft_config` to `model`.
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    args=training_arguments,
    packing=packing,
)
# Train model, optionally resuming from checkpoint step `ck`. The checkpoint is
# looked up inside this run's own output_dir, so the path follows the
# configured `data` value instead of a hard-coded "simple" run name.
ckpoint = os.path.join(output_dir, f"checkpoint-{ck}")
if is_recover:
    trainer.train(resume_from_checkpoint=ckpoint)
else:
    trainer.train()

# Save the trained model. The tokenizer is saved alongside it because a
# '[PAD]' token was added and the embeddings were resized before training;
# reloading the adapter later needs that same vocabulary.
final_dir = os.path.join(output_dir, "final")
trainer.model.save_pretrained(final_dir)
tokenizer.save_pretrained(final_dir)